"""
data cue是指数据中的引导解耦的线索。
"""

import argparse
from collections import defaultdict

import numpy as np
import wandb
from disentanglement_lib.data.ground_truth.ground_truth_data import RandomAction
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange

from disvae.models.anneal import *
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE, init_specific_model
from utils.datasets import get_dataloaders
from disentanglement_lib.data.ground_truth.named_data import get_named_ground_truth_data
from utils.helpers import set_seed
from utils.visualize import *


def get_preds(model, dl, channels=1):
    """Run the encoder over a whole dataloader and collect latent means.

    Parameters
    ----------
    model : module with an ``encoder`` returning ``(mu, logvar)``.
        Switched to eval mode for the pass and restored to train mode
        before returning.
    dl : iterable yielding ``(img, label)`` batches.
    channels : int, optional
        Channel count used when reshaping batches to
        ``(-1, channels, 64, 64)``.  Defaults to 1 to preserve the original
        grayscale behaviour; pass ``dataset.observation_shape[2]`` for
        colour datasets (the previous hard-coded 1 could not handle them).

    Returns
    -------
    (preds, reals, targets)
        Latent means, input images and labels, each concatenated into a
        single tensor over the whole dataloader.
    """
    model.eval()
    preds = []
    reals = []
    targets = []
    # no_grad hoisted around the whole loop: gradients are never needed here
    with torch.no_grad():
        for img, label in dl:
            img = img.view(-1, channels, 64, 64).cuda()
            mu, _ = model.encoder(img)
            preds.append(mu)
            reals.append(img)
            targets.append(label)
    model.train()
    return torch.cat(preds), torch.cat(reals), torch.cat(targets)


def evaluate(model, dl, storer):
    """Visualise latent/factor correlation for the top-KL dimensions.

    Ranks latent dimensions by their logged ``kl_*`` values in ``storer``,
    scatter-plots the three highest-KL dimensions against the first three
    ground-truth factors, and stores the figure under ``storer['correlated']``
    as a ``wandb.Image``.
    """
    latents, _, labels = get_preds(model, dl)

    # Rank latent dimensions by their recorded KL contribution.
    kl_values = torch.Tensor([val for key, val in storer.items() if 'kl_' in key])
    _, ranked = kl_values.sort(descending=True)

    if len(ranked) >= 3:
        sample = latents[:1000, ranked[:3]].cpu()
        factors = labels[:1000]
        fig, axes = plt.subplots(3, 3)
        for fac in range(3):  # ground-truth factor
            for var in range(3):  # latent variable
                ax = axes[fac, var]
                ax.scatter(sample[:, var].numpy(), factors[:, fac].numpy(), s=0.2)
                ax.set_title(f'{fac},{ranked[var]}')
        storer['correlated'] = wandb.Image(fig)

    plt.close()


def _mean_reduce(storer):
    """Replace every list value in ``storer`` with its mean, in place.

    Returns the same mapping for call-site convenience.
    """
    for k, v in storer.items():
        if isinstance(v, list):
            storer[k] = np.mean(v)
    return storer


def train(dl, model, iterations, loss_f):
    """Optimise ``model`` on ``dl`` for roughly ``iterations`` steps.

    Averaged metrics are pushed to wandb every 10 iterations, then the
    accumulator is reset.  Relies on the module-level ``config`` global
    (wandb run config) for the learning rate.

    Parameters
    ----------
    dl : DataLoader yielding ``(img, label)`` batches.
    model : VAE module (must be on CUDA; inputs are moved there).
    iterations : int, total optimisation steps before early return.
    loss_f : callable, ``loss_f(data, recon, latent_dist, is_train, storer,
        latent_sample=...)`` returning a scalar loss tensor.

    Returns
    -------
    dict mapping metric name to its mean over the last partial window.
    """
    opt = optim.Adam(model.parameters(), config.learning_rate, weight_decay=0)
    storer = defaultdict(list)
    batches_per_epoch = len(dl)  # hoisted: invariant across the loop
    for e in trange(iterations // batches_per_epoch):
        for i, (img, _) in enumerate(dl):
            img = img.float().cuda()

            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())
            # BUG FIX: the original computed this from the global
            # ``train_loader`` rather than the ``dl`` argument; they only
            # coincided because the caller passed the same global.
            itr = e * batches_per_epoch + i
            if itr % 10 == 0:
                storer = _mean_reduce(storer)
                storer['itr'] = itr
                # evaluate(model, dl, storer)
                wandb.log(storer, sync=False)

                storer = defaultdict(list)
            if itr > iterations:
                return _mean_reduce(storer)
    return _mean_reduce(storer)


# Sweep: 6 repetitions x 4 datasets x every ground-truth factor.
# NOTE(review): ``seed`` is never used inside the loop body (set_seed is
# commented out), so the six passes differ only via global RNG state.
for seed in range(6):
    # set_seed(seed)
    # "shapes3d", "cars3d", "smallnorb", "dsprites_full"
    #  "scream_dsprites", "noisy_dsprites", "color_dsprites"
    for ds_name in ["scream_dsprites", "dsprites_full", "color_dsprites", "smallnorb"]:
        # Run config logged to wandb.  NOTE(review): ``batch_size`` and
        # ``random_seed`` are recorded but never applied below — the
        # DataLoader is created with its default batch size.
        hyperparameter_defaults = dict(
            batch_size=1,
            learning_rate=0.0005,
            iterations=10001,
            loss='betaH',
            random_seed=345,
            dataset=ds_name
        )
        dataset = get_named_ground_truth_data(ds_name)

        # One wandb run per (dataset, factor): batches come from randomly
        # varying the single ground-truth factor ``a``.
        for a in range(dataset.num_factors):
            action = RandomAction(dataset, a)
            # Module-level globals: ``train_loader`` and ``config`` are also
            # read inside train() — do not rename or make local.
            train_loader = DataLoader(action, shuffle=True)
            wandb.init(project="experiment",
                       group=__file__.split('/')[-1][:-3], notes=__doc__,
                       reinit=True, config=hyperparameter_defaults, )
            wandb.config.action = a
            config = wandb.config
            iterations = config.iterations
            dim = 10  # latent dimensionality
            channel = dataset.observation_shape[2]  # 1 (grayscale) or 3 (colour)
            vae = init_specific_model("Higgins", (channel, 64, 64), dim, 5)
            vae.cuda()
            vae.train()
            # smallnorb gets a shorter KL-weight warm-up (50 vs 200 steps).
            if ds_name in ["smallnorb"]:
                anneal = get_anneal('monotonic', iterations, 50, 1)
            else:
                anneal = get_anneal('monotonic', iterations, 200, 1)
            loss_f = get_loss_s(config.loss, anneal, 1, len(action))
            train(train_loader, vae, iterations, loss_f)

            wandb.join()
